
#@title Self-guidance modules
import torch
import diffusers
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import math
from IPython.display import Image
from PIL import Image
import numpy as np
import os

from utils.utils import AttentionStore, search_sequence_numpy
from utils.SelfGuidanceEdits import SelfGuidanceEdits
from globals import *

class SelfGuidanceAttnProcessor:
    r"""
    Default processor for performing attention-related computations, extended
    to record attention maps for self-guidance.

    On every call the attention probabilities are averaged over heads, laid
    out as ``(batch, h, w, keys)`` and stored in ``attn._aux['attn']`` (a list
    whose last entry is the most recent map).
    """

    def __call__(
        self,
        attn: diffusers.models.attention_processor.Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Run attention for ``attn`` and record its attention map.

        Args:
            attn: Attention module whose projections and norms are applied.
                The recorded maps are attached to it as ``attn._aux``.
            hidden_states: Either a 4-D ``(batch, channel, height, width)``
                tensor (flattened internally) or a 3-D sequence tensor.
            encoder_hidden_states: Source of keys/values; when ``None`` the
                (normalized) ``hidden_states`` are used instead.
            attention_mask: Optional mask handed to
                ``attn.prepare_attention_mask``.
            temb: Only used by ``attn.spatial_norm`` when that is set.

        Returns:
            The attended hidden states, restored to the input's 4-D layout
            when the input was 4-D.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        ### SELF GUIDANCE
        # `query` has already been folded to (batch * heads, ...) above, so the
        # head axis must be exposed by splitting on batch_size; splitting on
        # len(query) would leave a size-1 axis and the mean would be a no-op.
        scores_ = attention_probs.reshape(batch_size, -1, *attention_probs.shape[1:]).mean(1)
        # The query positions are reshaped into a square grid, so their count
        # has to be a perfect square for the reshape below to succeed.
        h = w = math.isqrt(scores_.shape[1])
        scores_ = scores_.reshape(len(scores_), h, w, -1)
        if SG_RES != scores_.shape[2]:
            # resize expects channels-first, hence the permute round trip.
            scores_ = TF.resize(scores_.permute(0, 3, 1, 2), SG_RES, antialias=True).permute(0, 2, 3, 1)
        try:
            if not Globals.get('save_aux'):
                # Keep the list length but drop every earlier map.
                len_ = len(attn._aux['attn'])
                del attn._aux['attn']
                attn._aux['attn'] = [None] * len_ + [scores_]
            else:
                # Move the previous map to the CPU before appending the new one.
                attn._aux['attn'][-1] = attn._aux['attn'][-1].cpu()
                attn._aux['attn'].append(scores_)
        except Exception:
            # No usable history yet (missing attribute/key, empty list, or a
            # None placeholder as last entry): start over. The cleanup must not
            # raise when `_aux` or its 'attn' key does not exist.
            stale = getattr(attn, '_aux', None)
            if isinstance(stale, dict):
                stale.pop('attn', None)
            attn._aux = {'attn': [scores_]}
        ### END SELF GUIDANCE

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

class SelfGuidanceAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0),
    extended to record attention maps for self-guidance.

    Besides computing the attention output, every call derives head-averaged
    attention probabilities laid out as ``(batch, h, w, keys)`` and keeps them
    in ``attn._aux['attn']`` (a list whose last entry is the most recent map).
    """

    def __init__(self, attn_name, attn_store: AttentionStore, model):
        """Remember the layer name, the attention store and the owning model.

        Args:
            attn_name: Name of the attention layer; used in debug file names,
                and its part before the first ``_`` is passed to ``attn_store``.
            attn_store: Callable invoked with the raw attention probabilities
                when ``Globals.get('NO_OPT')`` is truthy.
            model: Object exposing a ``tokenizer`` (used by ``save_attn``).

        Raises:
            ImportError: If ``F.scaled_dot_product_attention`` is unavailable.
        """
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.attn_name = attn_name
        # Side length of the last attention grid seen; -1 until the first call.
        self.res = -1
        self.attn_store = attn_store
        self.model = model

    def save_attn(self, scores, attn_words_list, hard=False, thresh=0.5):
        """Write one grayscale PNG per entry of ``attn_words_list``.

        Args:
            scores: Attention maps with the token axis last. Only the second
                of two chunks along the first axis is used (presumably the
                text-conditioned half of a guidance batch -- TODO confirm).
            attn_words_list: Words/phrases whose token columns are visualized.
            hard: When true, the maps are first passed through
                ``SelfGuidanceEdits._attn_diff_norm(..., hard=True)``.
            thresh: Threshold forwarded to ``_attn_diff_norm`` in hard mode.
        """
        scores = scores.chunk(2)[1]
        for words in attn_words_list:
            word_ids = self.model.tokenizer(words, return_tensors='np')['input_ids']
            # Drop ids >= 49406 (looks like the tokenizer's special start/end
            # ids -- assumes a CLIP vocabulary, verify).
            word_ids = word_ids[word_ids < 49406]
            idxs = search_sequence_numpy(Globals.get('prompt_text_ids'), word_ids)
            image = scores[..., idxs]
            if hard:
                image = SelfGuidanceEdits._attn_diff_norm(image, hard=True, thresh=thresh)
            # Average over the selected tokens, then over the batch.
            image = image.mean(-1).mean(0)
            peak = image.max()
            if peak > 0:
                image = 255 * image / peak
            else:
                # All-zero map (or NaN from an empty token match): dividing
                # would yield NaN and an undefined uint8 cast, so write black.
                image = torch.zeros_like(image)
            image = image.unsqueeze(-1).expand(*image.shape, 3).cpu().detach()
            image = image.numpy().astype(np.uint8)
            image = Image.fromarray(image).resize((256, 256))
            image.save(os.path.join(Globals.get('attn_path'), f"step_{Globals.get('step')}_{words}_{self.attn_name}_{self.res}_{'hard' if hard else 'soft'}.png"))

    def _attention_maps(self, attn, query, key, scores_mask):
        """Return head-averaged attention probabilities as ``(batch, h, w, keys)``."""
        probs = attn.get_attention_scores(
            attn.head_to_batch_dim(query), attn.head_to_batch_dim(key), scores_mask
        )

        if Globals.get('NO_OPT'):
            # Hand the raw per-head probabilities to the external store.
            self.attn_store(probs, True, self.attn_name.split('_')[0], Globals.get('step'))

        # `query` is still (batch, ...) here, so this exposes and averages the head axis.
        maps = probs.reshape(len(query), -1, *probs.shape[1:]).mean(1)
        # Query positions are folded into a square grid; their count must be
        # a perfect square for the reshape to succeed.
        side = math.isqrt(maps.shape[1])
        self.res = side
        maps = maps.reshape(len(maps), side, side, -1)
        if SG_RES != maps.shape[2]:
            # resize expects channels-first, hence the permute round trip.
            maps = TF.resize(maps.permute(0, 3, 1, 2), SG_RES, antialias=True).permute(0, 2, 3, 1)
        return maps

    @staticmethod
    def _record_maps(attn, maps, new_timestep):
        """Store ``maps`` in ``attn._aux['attn']``: append on a new timestep, else overwrite the last entry."""
        try:
            keep_history = Globals.get('save_aux')
            history = attn._aux['attn']
            if new_timestep:
                if keep_history:
                    # Move the previous map to the CPU before appending.
                    history[-1] = history[-1].cpu()
                    history.append(maps)
                else:
                    # Keep the list length but drop every earlier map.
                    attn._aux['attn'] = [None] * len(history) + [maps]
            elif keep_history:
                history[-1] = maps
            else:
                attn._aux['attn'] = [None] * (len(history) - 1) + [maps]
        except Exception:
            # No usable history yet (missing attribute/key, empty list, or a
            # None placeholder as last entry): start over with a fresh list.
            stale = getattr(attn, '_aux', None)
            if isinstance(stale, dict):
                stale.pop('attn', None)
            attn._aux = {'attn': [maps]}

    def __call__(
        self,
        attn: diffusers.models.attention_processor.Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Run scaled-dot-product attention for ``attn`` and record its attention map.

        Args:
            attn: Attention module whose projections and norms are applied.
                The recorded maps are attached to it as ``attn._aux``.
            hidden_states: Either a 4-D ``(batch, channel, height, width)``
                tensor (flattened internally) or a 3-D sequence tensor.
            encoder_hidden_states: Source of keys/values; when ``None`` the
                (normalized) ``hidden_states`` are used instead.
            attention_mask: Optional mask handed to
                ``attn.prepare_attention_mask``.
            temb: Only used by ``attn.spatial_norm`` when that is set.

        Returns:
            The attended hidden states, restored to the input's 4-D layout
            when the input was 4-D.
        """

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        # Mask for the explicit attention-probability computation below; it
        # must stay in the layout returned by prepare_attention_mask, not the
        # 4-D layout built for scaled_dot_product_attention.
        scores_mask = None
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            scores_mask = attention_mask
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        ### SELF GUIDANCE
        new_timestep = bool(Globals.get('new_timestep'))
        scores_ = self._attention_maps(attn, query, key, scores_mask)
        # NOTE(review): debug_steps, attn_words_list and global_thresh are read
        # as plain module names (not via Globals.get); confirm that
        # `from globals import *` actually provides them.
        if Globals.get('DEBUG') and Globals.get('step') in debug_steps:
            self.save_attn(scores_, attn_words_list)
            self.save_attn(scores_, attn_words_list, hard=True, thresh=global_thresh)
        self._record_maps(attn, scores_, new_timestep)
        ### END SELF GUIDANCE

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

